{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# LongLLaMA: Focused Transformer Training for Context Scaling\n",
        "**LongLLaMA is a large language model capable of handling long contexts of 256k tokens or even more**.\n",
        "\n",
        "It is built upon the foundation of [OpenLLaMA](https://github.com/openlm-research/open_llama) and fine-tuned using the [Focused Transformer (FoT)](https://arxiv.org/abs/2307.03170) method.\n",
        "\n",
        "This notebook is a demo of [LongLLaMA-Instruct-3Bv1.1](https://huggingface.co/syzymon/long_llama_3b_instruct), a [LongLLaMA-3Bv1.1](https://huggingface.co/syzymon/long_llama_3b_v1_1) fine-tuned using [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) and [ShareGPT-Processed](https://huggingface.co/datasets/zetavg/ShareGPT-Processed) datasets. Note that LongLLaMA-Instruct 3B is licensed differently due to the use of responses from GPT-4/GPT-3.5.\n",
        "Similarly to the [LongLLaMA 3B](https://huggingface.co/syzymon/long_llama_3b), the model weights can serve as the drop-in replacement of LLaMA in existing implementations (for short context up to 2048 tokens).\n",
        "\n",
        "This notebook is a research preview of [LongLLaMA-Instruct-3Bv1.1](https://huggingface.co/syzymon/long_llama_3b_instruct).\n",
        "For more, see the [FoT paper](https://arxiv.org/abs/2307.03170) and [GitHub repository](https://github.com/CStanKonrad/long_llama)."
      ],
      "metadata": {
        "id": "blR8hhA_e-4S"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Initial steps"
      ],
      "metadata": {
        "id": "ZTROtgpafB_8"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "La3cPHfCe7hU"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade pip\n",
        "!pip install transformers==4.30.0  sentencepiece accelerate -q"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "from transformers import LlamaTokenizer, AutoModelForCausalLM, TextStreamer, PreTrainedModel, PreTrainedTokenizer\n",
        "from typing import List, Optional"
      ],
      "metadata": {
        "id": "0x_utNYxfECA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "MODEL_PATH = (\n",
        "    \"syzymon/long_llama_3b_instruct\"\n",
        ")\n",
        "TOKENIZER_PATH = MODEL_PATH\n",
        "# to fit into colab GPU we will use reduced precision\n",
        "TORCH_DTYPE = torch.bfloat16\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    device = torch.device(\"cuda\")\n",
        "else:\n",
        "    device = torch.device(\"cpu\")"
      ],
      "metadata": {
        "id": "LrrFlXKMfHMs"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)\n",
        "\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    MODEL_PATH,\n",
        "    torch_dtype=TORCH_DTYPE,\n",
        "    device_map=device,\n",
        "    trust_remote_code=True,\n",
        "    # mem_attention_grouping is used\n",
        "    # to trade speed for memory usage\n",
        "    # for details, see the section Additional configuration\n",
        "    mem_attention_grouping=(1, 2048),\n",
        ")\n",
        "model.eval()"
      ],
      "metadata": {
        "id": "5fE5Z1ABfJUp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# The demo"
      ],
      "metadata": {
        "id": "ElblcG7MfOM4"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Summarization and question answering\n",
        "We used the [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) dataset to instruction tune the model.\n",
        "Here we show the ability of the model to answer questions about long documents.\n",
        "Note that as the model used in the demo has only 3B parameters, it can have trouble understanding complex documents."
      ],
      "metadata": {
        "id": "chPK2r8CfRcC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import urllib.request\n",
        "import tempfile\n",
        "import shutil\n",
        "import os\n",
        "\n",
        "\n",
        "def get_paper(url: str, main_file: Optional[str] = None):\n",
        "    with tempfile.TemporaryDirectory() as tmp_dir:\n",
        "        file_path = os.path.join(tmp_dir, \"_paper.tmp\")\n",
        "        if main_file is not None:\n",
        "            # we are dealing with an archive\n",
        "            file_path += \".tar\"\n",
        "\n",
        "        urllib.request.urlretrieve(url, file_path)\n",
        "\n",
        "        if main_file is not None:\n",
        "            # we are dealing with an archive\n",
        "            shutil.unpack_archive(file_path, tmp_dir)\n",
        "            main_file_path = os.path.join(tmp_dir, main_file)\n",
        "        else:\n",
        "            main_file_path = file_path\n",
        "\n",
        "        with open(main_file_path, \"r\") as f:\n",
        "            data = f.read()\n",
        "\n",
        "    return data\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def load_to_memory(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, text: str):\n",
        "    tokenized_data = tokenizer(text, return_tensors=\"pt\")\n",
        "    input_ids = tokenized_data.input_ids\n",
        "    input_ids = input_ids.to(model.device)\n",
        "    torch.manual_seed(0)\n",
        "    output = model(input_ids=input_ids)\n",
        "    memory = output.past_key_values\n",
        "    return memory\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def generate_with_memory(\n",
        "    model: PreTrainedModel, tokenizer: PreTrainedTokenizer, memory, prompt: str, temperature=0.6\n",
        "):\n",
        "    tokenized_data = tokenizer(prompt, return_tensors=\"pt\")\n",
        "    input_ids = tokenized_data.input_ids\n",
        "    input_ids = input_ids.to(model.device)\n",
        "\n",
        "    streamer = TextStreamer(tokenizer, skip_prompt=False)\n",
        "\n",
        "    new_memory = memory\n",
        "\n",
        "    stop = False\n",
        "    while not stop:\n",
        "        output = model(input_ids, past_key_values=new_memory)\n",
        "        new_memory = output.past_key_values\n",
        "        assert len(output.logits.shape) == 3\n",
        "        assert output.logits.shape[0] == 1\n",
        "        last_logit = output.logits[[0], [-1], :]\n",
        "        dist = torch.distributions.Categorical(logits=last_logit / temperature)\n",
        "        next_token = dist.sample()\n",
        "        if next_token[0] == tokenizer.eos_token_id:\n",
        "            streamer.put(next_token[None, :])\n",
        "            streamer.end()\n",
        "            stop = True\n",
        "        else:\n",
        "            input_ids = next_token[None, :]\n",
        "            streamer.put(input_ids)\n",
        "\n",
        "\n",
        "PROMPT_PREFIX = \"You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can.\\n\"\n",
        "\n",
        "\n",
        "def construct_question_prompt(question: str):\n",
        "    prompt = f\"\\nAnswer the following question breifly using information from the text above.\\nQuestion: {question}\\nAnswer: \"\n",
        "    return prompt\n",
        "\n",
        "\n",
        "def ask_model(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt: str, memory, seed=0):\n",
        "    tokenized_data = tokenizer(prompt, return_tensors=\"pt\")\n",
        "    input_ids = tokenized_data.input_ids\n",
        "    input_ids = input_ids.to(model.device)\n",
        "\n",
        "    torch.manual_seed(seed)\n",
        "    generate_with_memory(model, tokenizer, memory, prompt)"
      ],
      "metadata": {
        "id": "rt0PMfKDfLjy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "fot_paper = get_paper(url=\"https://raw.githubusercontent.com/CStanKonrad/long_llama/main/assets/fot_paper.tar\", main_file=\"fot_paper.tex\")\n",
        "fot_memory = load_to_memory(model, tokenizer, PROMPT_PREFIX + fot_paper)"
      ],
      "metadata": {
        "id": "14ao1rk6faHM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "prompt = construct_question_prompt(\"What is the paper above about? Summarize it briefly.\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ],
      "metadata": {
        "id": "uz8Gja1Nfdkb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "prompt = construct_question_prompt(\"What method is introduced in the paper?\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ],
      "metadata": {
        "id": "3tQfPSURfejw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "prompt = construct_question_prompt(\"How is the 3B model called by the authors?\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ],
      "metadata": {
        "id": "s0xg6aBAfgDG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "prompt = construct_question_prompt(\"Name at least one author of the presented paper.\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ],
      "metadata": {
        "id": "sd3wPEzAfhFF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "prompt = construct_question_prompt(\"What is the distraction issue?\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ],
      "metadata": {
        "id": "6-mDT8TufjCl"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Chat\n",
        "We have also used [ShareGPT-Processed](https://huggingface.co/datasets/zetavg/ShareGPT-Processed) dataset to enhance the model conversation abilities. The chat prompt was inspired by [LongChat](https://github.com/DachengLi1/LongChat)."
      ],
      "metadata": {
        "id": "QPTv8WGAfklZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class ChatOutputBuffer:\n",
        "    \"\"\"\n",
        "    For providing online output that\n",
        "    is truncated after generating specified (stop_text)\n",
        "    sequence of characters\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, stop_text: List[str], tokenizer: PreTrainedModel):\n",
        "        self.tokenizer = tokenizer\n",
        "        self.streamer = TextStreamer(tokenizer, skip_prompt=False)\n",
        "        self.max_stop_seq = 0\n",
        "        self.stop_seq = []\n",
        "        for st in stop_text:\n",
        "            self.stop_seq.append(st)\n",
        "            self.max_stop_seq = max(self.max_stop_seq, len(st))\n",
        "\n",
        "        self.output_buffer = np.empty((0,), dtype=np.int64)\n",
        "\n",
        "    def reset_output_buffer(self):\n",
        "        self.output_buffer = np.empty((0,), dtype=np.int64)\n",
        "\n",
        "    def advance_output(self):\n",
        "        beg = 0\n",
        "        end = len(self.output_buffer) - self.max_stop_seq\n",
        "\n",
        "        if end > beg:\n",
        "            output = self.output_buffer[beg:end]\n",
        "            self.streamer.put(output)\n",
        "            self.output_buffer = self.output_buffer[end:]\n",
        "\n",
        "    def flush_buffer(self):\n",
        "        if len(self.output_buffer) > 0:\n",
        "            self.streamer.put(self.output_buffer)\n",
        "            self.output_buffer = self.output_buffer[len(self.output_buffer) :]\n",
        "        self.streamer.end()\n",
        "\n",
        "    def generation_too_long(self, text: str) -> int:\n",
        "        end_requests = 0\n",
        "        for st in self.stop_seq:\n",
        "            if text.endswith(st):\n",
        "                end_requests += 1\n",
        "        return end_requests\n",
        "\n",
        "    def update_buffer(self, next_tok: int) -> bool:\n",
        "        assert isinstance(next_tok, int)\n",
        "\n",
        "        array_next_tok = np.array([next_tok], dtype=np.int64)\n",
        "        self.output_buffer = np.concatenate([self.output_buffer, array_next_tok], axis=0)\n",
        "\n",
        "        suffix = self.output_buffer[-self.max_stop_seq :]\n",
        "        decoded = self.tokenizer.decode(suffix)\n",
        "        end_requests = self.generation_too_long(decoded)\n",
        "        if end_requests > 0:\n",
        "            decoded = self.tokenizer.decode(suffix[1:])\n",
        "            while self.generation_too_long(decoded) == end_requests:\n",
        "                suffix = suffix[1:]\n",
        "                decoded = self.tokenizer.decode(suffix[1:])\n",
        "\n",
        "            left_intact = len(self.output_buffer) - len(suffix)\n",
        "\n",
        "            self.output_buffer = self.output_buffer[:left_intact]\n",
        "            self.flush_buffer()\n",
        "            return True\n",
        "\n",
        "        self.advance_output()\n",
        "        return False\n",
        "\n",
        "\n",
        "class SimpleChatBot:\n",
        "    def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):\n",
        "        self.model = model\n",
        "        self.tokenizer = tokenizer\n",
        "        self.prompt = \"A chat between a user (denoted as USER:) and an artificial intelligence assistant (denoted as ASSISTANT:). The assistant gives helpful, detailed, and polite answers to the user's questions.\\n\\n\"\n",
        "        self.tokenized_prompt = self.tokenizer.encode(self.prompt, return_tensors=\"pt\", add_special_tokens=False)\n",
        "        self.tokenized_prompt = torch.concatenate(\n",
        "            [torch.tensor([self.tokenizer.bos_token_id], dtype=torch.long).reshape(1, 1), self.tokenized_prompt],\n",
        "            dim=-1,\n",
        "        )\n",
        "        self.model_name = \"\\nASSISTANT: \"\n",
        "        self.tokenized_model_name = self.tokenizer.encode(\n",
        "            self.model_name, return_tensors=\"pt\", add_special_tokens=False\n",
        "        )\n",
        "        self.user_name = \"\\nUSER: \"\n",
        "        self.tokenized_user_name = self.tokenizer.encode(self.user_name, return_tensors=\"pt\", add_special_tokens=False)\n",
        "        self.past_key_values = None\n",
        "\n",
        "        self.t = 0.8\n",
        "        self.output_buffer = ChatOutputBuffer(\n",
        "            [self.model_name.strip(), self.user_name.strip(), self.tokenizer.eos_token], self.tokenizer\n",
        "        )\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def ask(self, text: str):\n",
        "        input_ids = self.tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=False)\n",
        "\n",
        "        input_ids = torch.concatenate([self.tokenized_user_name, input_ids, self.tokenized_model_name], dim=-1)\n",
        "\n",
        "        if self.past_key_values is None:\n",
        "            input_ids = torch.concatenate([self.tokenized_prompt, input_ids], dim=-1)\n",
        "\n",
        "        self.output_buffer.reset_output_buffer()\n",
        "        output_text = self.model_name\n",
        "        output_ids = self.tokenizer.encode(\n",
        "            output_text, return_tensors=\"pt\", add_special_tokens=self.past_key_values is None\n",
        "        )\n",
        "        self.output_buffer.streamer.put(output_ids)\n",
        "\n",
        "        is_writing = True\n",
        "\n",
        "        while is_writing:\n",
        "            input_ids = input_ids.to(model.device)\n",
        "            output = self.model(input_ids=input_ids, past_key_values=self.past_key_values)\n",
        "\n",
        "            logits = output.logits\n",
        "            assert len(logits.shape) == 3\n",
        "            assert logits.shape[0] == 1\n",
        "            last_logit = logits[[0], [-1], :]\n",
        "\n",
        "            dist = torch.distributions.Categorical(logits=last_logit / self.t)\n",
        "            next_token = dist.sample()\n",
        "            # Note that parts of cut out text may remain in model memory\n",
        "            # this is implemented in this way for performance reasons\n",
        "            past_key_values = output.past_key_values\n",
        "            assert len(next_token.shape) == 1\n",
        "            should_stop = self.output_buffer.update_buffer(next_token[0].cpu().item())\n",
        "            if should_stop:\n",
        "                is_writing = False\n",
        "            else:\n",
        "                input_ids = next_token[None, :]\n",
        "                self.past_key_values = past_key_values\n"
      ],
      "metadata": {
        "id": "dDzevraXflyn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Feel free to try the chat yourself:"
      ],
      "metadata": {
        "id": "UYaj4mlIf7gi"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "chatbot = SimpleChatBot(model=model, tokenizer=tokenizer)\n",
        "while True:\n",
        "    user_text = input(\"USER: \")\n",
        "    chatbot.ask(user_text)"
      ],
      "metadata": {
        "id": "qjd33BLlfph5"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}